{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi Index Search\n",
    "This notebook demonstrates multi-index search using the GraphRAG API.\n",
    "\n",
    "Indexes created from Wikipedia state articles for Alaska, California, DC, Maryland, NY and Washington are used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from graphrag.api.query import (\n",
    "    multi_index_basic_search,\n",
    "    multi_index_drift_search,\n",
    "    multi_index_global_search,\n",
    "    multi_index_local_search,\n",
    ")\n",
    "from graphrag.config.create_graphrag_config import create_graphrag_config\n",
    "\n",
    "indexes = [\"alaska\", \"california\", \"dc\", \"maryland\", \"ny\", \"washington\"]\n",
    "indexes = sorted(indexes)\n",
    "\n",
    "print(indexes)\n",
    "\n",
    "vector_store_configs = {\n",
    "    index: {\n",
    "        \"type\": \"lancedb\",\n",
    "        \"db_uri\": f\"inputs/{index}/lancedb\",\n",
    "        \"container_name\": \"default\",\n",
    "        \"overwrite\": True,\n",
    "        \"index_name\": f\"{index}\",\n",
    "    }\n",
    "    for index in indexes\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_data = {\n",
    "    \"models\": {\n",
    "        \"default_chat_model\": {\n",
    "            \"model_supports_json\": True,\n",
    "            \"parallelization_num_threads\": 50,\n",
    "            \"parallelization_stagger\": 0.3,\n",
    "            \"async_mode\": \"threaded\",\n",
    "            \"type\": \"azure_openai_chat\",\n",
    "            \"model\": \"gpt-4o\",\n",
    "            \"auth_type\": \"azure_managed_identity\",\n",
    "            \"api_base\": \"<API_BASE_URL>\",\n",
    "            \"api_version\": \"2024-02-15-preview\",\n",
    "            \"deployment_name\": \"gpt-4o\",\n",
    "        },\n",
    "        \"default_embedding_model\": {\n",
    "            \"parallelization_num_threads\": 50,\n",
    "            \"parallelization_stagger\": 0.3,\n",
    "            \"async_mode\": \"threaded\",\n",
    "            \"type\": \"azure_openai_embedding\",\n",
    "            \"model\": \"text-embedding-3-large\",\n",
    "            \"auth_type\": \"azure_managed_identity\",\n",
    "            \"api_base\": \"<API_BASE_URL>\",\n",
    "            \"api_version\": \"2024-02-15-preview\",\n",
    "            \"deployment_name\": \"text-embedding-3-large\",\n",
    "        },\n",
    "    },\n",
    "    \"vector_store\": vector_store_configs,\n",
    "    \"local_search\": {\n",
    "        \"prompt\": \"prompts/local_search_system_prompt.txt\",\n",
    "        \"llm_max_tokens\": 12000,\n",
    "    },\n",
    "    \"global_search\": {\n",
    "        \"map_prompt\": \"prompts/global_search_map_system_prompt.txt\",\n",
    "        \"reduce_prompt\": \"prompts/global_search_reduce_system_prompt.txt\",\n",
    "        \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n",
    "    },\n",
    "    \"drift_search\": {\n",
    "        \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n",
    "        \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n",
    "    },\n",
    "    \"basic_search\": {\"prompt\": \"prompts/basic_search_system_prompt.txt\"},\n",
    "}\n",
    "parameters = create_graphrag_config(config_data, \".\")\n",
    "loop = asyncio.get_event_loop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Global Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n",
    "communities = [\n",
    "    pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n",
    "]\n",
    "community_reports = [\n",
    "    pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n",
    "]\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_global_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        indexes,\n",
    "        1,\n",
    "        False,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"Describe this dataset.\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n",
    "    index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "        \"index_name\"\n",
    "    ]\n",
    "    index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "        \"index_id\"\n",
    "    ]\n",
    "    print(report_id, index_name, index_id)\n",
    "    index_reports = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "    )\n",
    "    print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"])  # noqa: RUF015\n",
    "    print(\n",
    "        index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n",
    "            0\n",
    "        ]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multi-index Local Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n",
    "communities = [\n",
    "    pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n",
    "]\n",
    "community_reports = [\n",
    "    pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n",
    "]\n",
    "covariates = [\n",
    "    pd.read_parquet(f\"inputs/{index}/covariates.parquet\") for index in indexes\n",
    "]\n",
    "text_units = [\n",
    "    pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n",
    "]\n",
    "relationships = [\n",
    "    pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n",
    "]\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_local_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        text_units,\n",
    "        relationships,\n",
    "        covariates,\n",
    "        indexes,\n",
    "        1,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"weather\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for report_id in [47, 213]:\n",
    "    index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "        \"index_name\"\n",
    "    ]\n",
    "    index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "        \"index_id\"\n",
    "    ]\n",
    "    print(report_id, index_name, index_id)\n",
    "    index_reports = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "    )\n",
    "    print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"])  # noqa: RUF015\n",
    "    print(\n",
    "        index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n",
    "            0\n",
    "        ]\n",
    "    )\n",
    "for entity_id in [500, 502, 506, 1960, 1961, 1962]:\n",
    "    index_name = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][  # noqa: RUF015\n",
    "        \"index_name\"\n",
    "    ]\n",
    "    index_id = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][  # noqa: RUF015\n",
    "        \"index_id\"\n",
    "    ]\n",
    "    print(entity_id, index_name, index_id)\n",
    "    index_entities = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_entities.parquet\"\n",
    "    )\n",
    "    print(\n",
    "        [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][  # noqa: RUF015\n",
    "            \"description\"\n",
    "        ][:100]\n",
    "    )\n",
    "    print(\n",
    "        index_entities[index_entities[\"human_readable_id\"] == int(index_id)][\n",
    "            \"description\"\n",
    "        ].to_numpy()[0][:100]\n",
    "    )\n",
    "for relationship_id in [1805, 1806]:\n",
    "    index_name = [  # noqa: RUF015\n",
    "        i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n",
    "    ][0][\"index_name\"]\n",
    "    index_id = [  # noqa: RUF015\n",
    "        i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n",
    "    ][0][\"index_id\"]\n",
    "    print(relationship_id, index_name, index_id)\n",
    "    index_relationships = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_relationships.parquet\"\n",
    "    )\n",
    "    print(\n",
    "        [i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)][0][  # noqa: RUF015\n",
    "            \"description\"\n",
    "        ]\n",
    "    )\n",
    "    print(\n",
    "        index_relationships[index_relationships[\"human_readable_id\"] == int(index_id)][\n",
    "            \"description\"\n",
    "        ].to_numpy()[0]\n",
    "    )\n",
    "for claim_id in [100]:\n",
    "    index_name = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][  # noqa: RUF015\n",
    "        \"index_name\"\n",
    "    ]\n",
    "    index_id = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][  # noqa: RUF015\n",
    "        \"index_id\"\n",
    "    ]\n",
    "    print(relationship_id, index_name, index_id)\n",
    "    index_claims = pd.read_parquet(\n",
    "        f\"inputs/{index_name}/create_final_covariates.parquet\"\n",
    "    )\n",
    "    print(\n",
    "        [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][\"description\"]  # noqa: RUF015\n",
    "    )\n",
    "    print(\n",
    "        index_claims[index_claims[\"human_readable_id\"] == int(index_id)][\n",
    "            \"description\"\n",
    "        ].to_numpy()[0]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Drift Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n",
    "communities = [\n",
    "    pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n",
    "]\n",
    "community_reports = [\n",
    "    pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n",
    "]\n",
    "text_units = [\n",
    "    pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n",
    "]\n",
    "relationships = [\n",
    "    pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n",
    "]\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_drift_search(\n",
    "        parameters,\n",
    "        entities,\n",
    "        communities,\n",
    "        community_reports,\n",
    "        text_units,\n",
    "        relationships,\n",
    "        indexes,\n",
    "        1,\n",
    "        \"Multiple Paragraphs\",\n",
    "        False,\n",
    "        \"agriculture\",\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for report_id in [47, 236]:\n",
    "    for question in results[1]:\n",
    "        resq = results[1][question]\n",
    "        if len(resq[\"reports\"]) == 0:\n",
    "            continue\n",
    "        if len([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)]) == 0:\n",
    "            continue\n",
    "        index_name = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "            \"index_name\"\n",
    "        ]\n",
    "        index_id = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][  # noqa: RUF015\n",
    "            \"index_id\"\n",
    "        ]\n",
    "        print(question, report_id, index_name, index_id)\n",
    "        index_reports = pd.read_parquet(\n",
    "            f\"inputs/{index_name}/create_final_community_reports.parquet\"\n",
    "        )\n",
    "        print([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"])  # noqa: RUF015\n",
    "        print(\n",
    "            index_reports[index_reports[\"community\"] == int(index_id)][\n",
    "                \"title\"\n",
    "            ].to_numpy()[0]\n",
    "        )\n",
    "        break\n",
    "for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n",
    "    for question in results[1]:\n",
    "        resq = results[1][question]\n",
    "        if len(resq[\"sources\"]) == 0:\n",
    "            continue\n",
    "        if len([i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)]) == 0:\n",
    "            continue\n",
    "        index_name = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][  # noqa: RUF015\n",
    "            \"index_name\"\n",
    "        ]\n",
    "        index_id = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][  # noqa: RUF015\n",
    "            \"index_id\"\n",
    "        ]\n",
    "        print(question, source_id, index_name, index_id)\n",
    "        index_sources = pd.read_parquet(\n",
    "            f\"inputs/{index_name}/create_final_text_units.parquet\"\n",
    "        )\n",
    "        print(\n",
    "            [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250]  # noqa: RUF015\n",
    "        )\n",
    "        print(index_sources.loc[int(index_id)][\"text\"][:250])\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-index Basic Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_units = [\n",
    "    pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n",
    "]\n",
    "\n",
    "task = loop.create_task(\n",
    "    multi_index_basic_search(\n",
    "        parameters, text_units, indexes, False, \"industry in maryland\"\n",
    "    )\n",
    ")\n",
    "results = await task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Print report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show context links back to original text\n",
    "\n",
    "Note that original index name is not saved in context data for basic search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for source_id in [0, 1]:\n",
    "    print(results[1][\"sources\"][source_id][\"text\"][:250])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
